package xrsa

import (
	"context"
	"crypto"
	"crypto/rand"
	"crypto/rsa"
	"crypto/x509"
	"encoding/base64"
	"encoding/hex"
	"encoding/pem"
	"errors"
	"fmt"
)

// XPublicKey wraps an RSA public key used for encryption and signature
// verification.
//
// max_decryt_block_size is the modulus length in bytes (the size of one
// ciphertext block); max_encryt_block_size is the largest plaintext chunk that
// fits in one PKCS#1 v1.5 block, i.e. the modulus length minus the 11 bytes of
// padding overhead.
type XPublicKey struct {
	publicKey             *rsa.PublicKey
	max_decryt_block_size int
	max_encryt_block_size int
}

// NewPemPublicKey parses a PEM-encoded PKIX RSA public key.
//
// Only the first PEM block in keyData is used. An error is returned when no
// PEM block can be decoded, when the DER payload is not a valid PKIX public
// key, or when the key is not an RSA key.
func NewPemPublicKey(keyData []byte) (*XPublicKey, error) {
	block, _ := pem.Decode(keyData)
	if block == nil {
		return nil, errors.New("public key error")
	}

	key, err := x509.ParsePKIXPublicKey(block.Bytes)
	if err != nil {
		return nil, err
	}
	pubKey, ok := key.(*rsa.PublicKey)
	if !ok {
		// err is nil at this point, so an explicit error is required;
		// otherwise callers would receive (nil, nil).
		return nil, fmt.Errorf("public key error: expected RSA key, got %T", key)
	}
	// Size rounds the modulus bit length up to whole bytes, which is the
	// length of every ciphertext block rsa.EncryptPKCS1v15 produces.
	keySize := pubKey.Size()
	return &XPublicKey{
		publicKey:             pubKey,
		max_decryt_block_size: keySize,
		max_encryt_block_size: keySize - 11,
	}, nil
}

// EncryptEncode encrypts plaintext with Encrypt and also returns the
// ciphertext rendered as text by encoder. When encryption fails, both the byte
// and string results are empty and the error is returned.
func (m *XPublicKey) EncryptEncode(ctx context.Context, plaintext []byte, encoder RsaCoder) ([]byte, string, error) {
	cipherBytes, err := m.Encrypt(ctx, plaintext)
	if err == nil {
		return cipherBytes, encoder.Bytes2string(cipherBytes), nil
	}
	return nil, "", err
}

// Encrypt encrypts plaintext with the public key using RSA PKCS#1 v1.5.
//
// A single RSA operation can only encrypt a limited number of bytes, so
// plaintext is split into chunks of at most max_encryt_block_size bytes. Each
// chunk is encrypted separately and the ciphertext blocks are concatenated in
// order. An empty plaintext yields an empty, non-nil result. ctx is currently
// unused.
//
// An error is returned when the chunk size is not positive (e.g. a zero-value
// XPublicKey or a key too small for PKCS#1 v1.5 padding) or when any chunk
// fails to encrypt.
func (m *XPublicKey) Encrypt(ctx context.Context, plaintext []byte) ([]byte, error) {
	maxBlock := m.max_encryt_block_size
	if maxBlock <= 0 {
		// Without this guard a zero chunk size never advances the offset
		// (infinite loop) and a negative one panics on the slice bound.
		return nil, fmt.Errorf("xrsa: invalid encryption block size %d", maxBlock)
	}

	// Capacity hint only: each chunk becomes one ciphertext block of roughly
	// maxBlock+11 bytes.
	chunks := (len(plaintext) + maxBlock - 1) / maxBlock
	out := make([]byte, 0, chunks*(maxBlock+11))

	// Encrypt the data chunk by chunk.
	for offset := 0; offset < len(plaintext); offset += maxBlock {
		end := offset + maxBlock
		if end > len(plaintext) {
			end = len(plaintext)
		}
		block, err := rsa.EncryptPKCS1v15(rand.Reader, m.publicKey, plaintext[offset:end])
		if err != nil {
			return nil, err
		}
		out = append(out, block...)
	}
	return out, nil
}

// DecodeVerify first decodes the textual signature sign with encoder, then
// checks the resulting raw signature against src using Verify.
func (m *XPublicKey) DecodeVerify(ctx context.Context, src []byte, sign string, hash crypto.Hash, encoder RsaCoder) error {
	rawSignature := encoder.String2bytes(sign)
	return m.Verify(ctx, src, rawSignature, hash)
}

// Verify checks an RSA PKCS#1 v1.5 signature over src.
//
// src is hashed with hash, and the digest is verified against sign using the
// public key. A nil error means the signature is valid. ctx is currently
// unused. If hash is not linked into the binary, an error is returned rather
// than the panic crypto.Hash.New would raise.
func (m *XPublicKey) Verify(ctx context.Context, src []byte, sign []byte, hash crypto.Hash) error {
	if !hash.Available() {
		return fmt.Errorf("xrsa: hash function %v is unavailable", hash)
	}
	h := hash.New()
	h.Write(src) // hash.Hash.Write never returns an error
	return rsa.VerifyPKCS1v15(m.publicKey, hash, h.Sum(nil), sign)
}

// XPrivateKey wraps an RSA private key used for decryption and signing.
//
// max_decryt_block_size is the modulus length in bytes (the size of one
// ciphertext block); max_encryt_block_size is the modulus length minus the 11
// bytes of PKCS#1 v1.5 padding overhead.
type XPrivateKey struct {
	privateKey            *rsa.PrivateKey
	max_decryt_block_size int
	max_encryt_block_size int
}

// PKCSType selects the DER container format of a PEM-encoded private key.
type PKCSType int

const (
	// PKCS1 selects x509.ParsePKCS1PrivateKey (RSA-only PKCS #1 encoding).
	PKCS1 PKCSType = iota
	// PKCS8 selects x509.ParsePKCS8PrivateKey; the contained key must be RSA.
	PKCS8
)

// NewPemPrivateKey parses a PEM-encoded RSA private key in the format selected
// by pkcsType.
//
// Only the first PEM block in keyData is used. An error is returned when no
// PEM block can be decoded, when pkcsType is unsupported, when the DER payload
// fails to parse, or when the key is not an RSA key.
func NewPemPrivateKey(keyData []byte, pkcsType PKCSType) (*XPrivateKey, error) {
	block, _ := pem.Decode(keyData)
	if block == nil {
		return nil, errors.New("private key error")
	}

	var key interface{}
	var err error
	switch pkcsType {
	case PKCS1:
		key, err = x509.ParsePKCS1PrivateKey(block.Bytes)
	case PKCS8:
		key, err = x509.ParsePKCS8PrivateKey(block.Bytes)
	default:
		return nil, fmt.Errorf("not support pkcstype: %d", pkcsType)
	}
	if err != nil {
		return nil, err
	}

	priKey, ok := key.(*rsa.PrivateKey)
	if !ok {
		// PKCS #8 can carry ECDSA/Ed25519 keys; err is nil here, so an
		// explicit error is needed to avoid returning (nil, nil).
		return nil, fmt.Errorf("private key error: expected RSA key, got %T", key)
	}
	// Size rounds the modulus bit length up to whole bytes, which is the
	// length of every ciphertext block.
	keySize := priKey.Size()
	return &XPrivateKey{
		privateKey:            priKey,
		max_decryt_block_size: keySize,
		max_encryt_block_size: keySize - 11,
	}, nil
}

// RsaCoder converts binary RSA output (ciphertext or signatures) to and from a
// printable string form.
type RsaCoder interface {
	// Bytes2string renders raw bytes as a printable string.
	Bytes2string(raw []byte) string
	// String2bytes turns a string produced by Bytes2string back into raw bytes.
	String2bytes(text string) []byte
}

// Compile-time checks that both encoders satisfy RsaCoder.
var (
	_ RsaCoder = (*HexRsaEncoder)(nil)
	_ RsaCoder = (*Base64RsaEncoder)(nil)
)

// HexRsaEncoder implements RsaCoder with hexadecimal text.
type HexRsaEncoder struct{}

// String2bytes decodes hexadecimal text. RsaCoder has no error result, so a
// decoding error is dropped and the caller gets whatever hex.DecodeString
// produced before the error (possibly empty).
func (*HexRsaEncoder) String2bytes(s string) []byte {
	decoded, _ := hex.DecodeString(s)
	return decoded
}

// Bytes2string encodes b as hexadecimal text.
func (*HexRsaEncoder) Bytes2string(b []byte) string {
	return hex.EncodeToString(b)
}

// Base64RsaEncoder implements RsaCoder with standard (padded) base64 text.
type Base64RsaEncoder struct{}

// String2bytes decodes standard base64 text. As with the hex coder, a decoding
// error is dropped and the caller gets whatever the decoder produced before
// the error.
func (*Base64RsaEncoder) String2bytes(s string) []byte {
	decoded, _ := base64.StdEncoding.DecodeString(s)
	return decoded
}

// Bytes2string encodes b as standard base64 text.
func (*Base64RsaEncoder) Bytes2string(b []byte) string {
	return base64.StdEncoding.EncodeToString(b)
}

// Public-key encryption followed by string encoding of the ciphertext is provided by XPublicKey.EncryptEncode.

// DecodeDecrypt first turns the textual ciphertext back into raw bytes with
// decoder, then decrypts those bytes with Decrypt.
func (m *XPrivateKey) DecodeDecrypt(ctx context.Context, ciphertext string, decoder RsaCoder) ([]byte, error) {
	rawCipher := decoder.String2bytes(ciphertext)
	return m.Decrypt(ctx, rawCipher)
}

// Decrypt decrypts cipherData with the private key using RSA PKCS#1 v1.5.
//
// cipherData is split into blocks of max_decryt_block_size bytes (the final
// block may be shorter). Each block is decrypted separately and the plaintexts
// are concatenated in order. This presumably mirrors the chunked output of
// XPublicKey.Encrypt with the matching public key. An empty cipherData yields
// an empty, non-nil result. ctx is currently unused.
//
// On failure, the error from rsa.DecryptPKCS1v15 is wrapped (retrievable with
// errors.Is/errors.As) together with the failing block offset. Nothing is
// printed.
func (m *XPrivateKey) Decrypt(ctx context.Context, cipherData []byte) ([]byte, error) {
	blockSize := m.max_decryt_block_size
	if blockSize <= 0 {
		// A zero-value XPrivateKey would otherwise reach DecryptPKCS1v15
		// with a nil key, and a negative size would panic on the slice bound.
		return nil, fmt.Errorf("xrsa: invalid decryption block size %d", blockSize)
	}

	// Each PKCS#1 v1.5 plaintext is shorter than its ciphertext block, so the
	// input length is an upper bound on the output length.
	out := make([]byte, 0, len(cipherData))

	// Decrypt the data block by block.
	for offset := 0; offset < len(cipherData); offset += blockSize {
		end := offset + blockSize
		if end > len(cipherData) {
			end = len(cipherData)
		}
		block, err := rsa.DecryptPKCS1v15(rand.Reader, m.privateKey, cipherData[offset:end])
		if err != nil {
			return nil, fmt.Errorf("xrsa: decrypting block at offset %d of %d-byte input: %w", offset, len(cipherData), err)
		}
		out = append(out, block...)
	}
	return out, nil
}

// Sign hashes src with hash and signs the digest with the private key using
// RSA PKCS#1 v1.5. If hash is not linked into the binary, an error is returned
// rather than the panic crypto.Hash.New would raise.
func (m *XPrivateKey) Sign(src []byte, hash crypto.Hash) ([]byte, error) {
	if !hash.Available() {
		return nil, fmt.Errorf("xrsa: hash function %v is unavailable", hash)
	}
	h := hash.New()
	h.Write(src) // hash.Hash.Write never returns an error
	return rsa.SignPKCS1v15(rand.Reader, m.privateKey, hash, h.Sum(nil))
}
// SignEncode signs src with Sign and returns the signature rendered as text by
// encoder. On failure it returns an empty string and the error without
// invoking encoder, so no partially-formed result is produced.
func (m *XPrivateKey) SignEncode(src []byte, hash crypto.Hash, encoder RsaCoder) (string, error) {
	signature, err := m.Sign(src, hash)
	if err != nil {
		return "", err
	}
	return encoder.Bytes2string(signature), nil
}
